{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[View the runnable example on GitHub](https://github.com/intel-analytics/BigDL/tree/main/python/nano/tutorial/notebook/inference/pytorch/inference_optimizer_optimize.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Find Acceleration Method with the Minimum Inference Latency using InferenceOptimizer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This example illustrates how to apply InferenceOptimizer to quickly find acceleration method with the minimum inference latency under specific restrictions or without restrictions for a trained model. \n",
    "In this example, we first train ResNet18 model on the [cats and dogs dataset](https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip). Then, by calling `optimize()`, we can obtain all available accelaration combinations provided by BigDL-Nano for inference. By calling `get_best_model()` , we could get the best model under specific restrictions or without restrictions."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "nbsphinx": "hidden"
   },
   "source": [
    "To do inference using BigDL-Nano InferenceOptimizer, you need to install BigDL-Nano for PyTorch inference first. We recommend you to use [Miniconda](https://docs.conda.io/en/latest/miniconda.html) to prepare the environment and install the following packages in a conda environment. \n",
    "\n",
    "You can create a conda environment by executing:\n",
    "\n",
    "```bash\n",
    "# \"nano\" is conda environment name, you can use any name you like.\n",
    "conda create -n nano python=3.7 setuptools=58.0.4  \n",
    "conda activate nano\n",
    "pip install --pre --upgrade bigdl-nano[pytorch,inference]  # install the nightly-built version\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nbsphinx": "hidden"
   },
   "source": [
    "Then initialize environment variables with script `bigdl-nano-init` installed with bigdl-nano.\n",
    "\n",
    "```bash\n",
    "source bigdl-nano-init\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, prepare model and dataset. We use a pretrained ResNet18 model and train the model on [cats and dogs dataset](https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip) in this example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "nbsphinx": "hidden"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from pathlib import Path\n",
    "from torchmetrics.functional.classification.accuracy import multiclass_accuracy\n",
    "from torchvision import transforms\n",
    "from torchvision.datasets import ImageFolder\n",
    "from torchvision.datasets.utils import download_and_extract_archive\n",
    "from torch.utils.data import Subset, DataLoader\n",
    "from bigdl.nano.pytorch import Trainer\n",
    "\n",
    "def accuracy(pred, target):\n",
    "    pred = torch.sigmoid(pred)\n",
    "    return multiclass_accuracy(pred, target, num_classes=2)\n",
    "\n",
    "def prepare_model_and_dataset(model_ft, val_size):\n",
    "    DATA_URL = \"https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip\"\n",
    "\n",
    "    train_transform = transforms.Compose([\n",
    "        transforms.Resize((224, 224)),\n",
    "        transforms.RandomHorizontalFlip(),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
    "    ])\n",
    "\n",
    "    val_transform = transforms.Compose([\n",
    "        transforms.Resize((224, 224)),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
    "    ])\n",
    "\n",
    "    if not Path(\"data\").exists():\n",
    "        # download dataset\n",
    "        download_and_extract_archive(url=DATA_URL, download_root=\"data\", remove_finished=True)\n",
    "\n",
    "    data_path = Path(\"data/cats_and_dogs_filtered\")\n",
    "    train_dataset = ImageFolder(data_path.joinpath(\"train\"), transform=train_transform)\n",
    "    val_dataset = ImageFolder(data_path.joinpath(\"validation\"), transform=val_transform)\n",
    "\n",
    "    indices = torch.randperm(len(val_dataset))\n",
    "    val_dataset = Subset(val_dataset, indices=indices[:val_size])\n",
    "\n",
    "    train_dataloader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True)\n",
    "    val_dataloader = DataLoader(dataset=val_dataset, batch_size=8, shuffle=False)\n",
    "\n",
    "    num_ftrs = model_ft.fc.in_features\n",
    "    \n",
    "    model_ft.fc = torch.nn.Linear(num_ftrs, 2)\n",
    "    loss_ft = torch.nn.CrossEntropyLoss()\n",
    "    optimizer_ft = torch.optim.Adam(model_ft.parameters(), lr=1e-3)\n",
    "\n",
    "    # compile model\n",
    "    model = Trainer.compile(model_ft, loss=loss_ft, optimizer=optimizer_ft, metrics=[accuracy])\n",
    "    trainer = Trainer(max_epochs=1)\n",
    "    trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)\n",
    "    \n",
    "    return model, train_dataset, val_dataset\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision.models import resnet18\n",
    "\n",
    "model = resnet18(pretrained=True)\n",
    "_, train_dataset, val_dataset = prepare_model_and_dataset(model, val_size=500)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;_The full definition of function_ `prepare_model_and_dataset` _could be found in the_ [runnable example](https://github.com/intel-analytics/BigDL/tree/main/python/nano/tutorial/notebook/inference/pytorch/inference_optimizer_optimize.ipynb)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Obtain available accelaration combinations by `optimize`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. Default search mode\n",
    "To find acceleration method with the minimum inference latency, you could import `InferenceOptimizer` and call `optimize` method. The `optimize` method will run all possible acceleration combinations and output the result, it will take about 1 minute."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from bigdl.nano.pytorch import InferenceOptimizer\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "# Define metric for accuracy calculation\n",
    "def accuracy(pred, target):\n",
    "    pred = torch.sigmoid(pred)\n",
    "    return multiclass_accuracy(pred, target, num_classes=2)\n",
    "\n",
    "optimizer = InferenceOptimizer()\n",
    "\n",
    "# To obtain the latency of single sample, set batch_size=1\n",
    "train_dataloader = DataLoader(train_dataset, batch_size=1)\n",
    "val_dataloader = DataLoader(val_dataset)\n",
    "\n",
    "optimizer.optimize(model=model,\n",
    "                   training_data=train_dataloader,\n",
    "                   validation_data=val_dataloader,\n",
    "                   metric=accuracy,\n",
    "                   direction=\"max\",\n",
    "                   thread_num=1,\n",
    "                   latency_sample_num=100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The example output of `optimizer.optimize` is shown below.\n",
    "\n",
    "```bash\n",
    " -------------------------------- ---------------------- -------------- ----------------------\n",
    "|             method             |        status        | latency(ms)  |     metric value     |\n",
    " -------------------------------- ---------------------- -------------- ----------------------\n",
    "|            original            |      successful      |    29.796    |        0.794         |\n",
    "|              bf16              |      successful      |    16.853    |        0.794         |\n",
    "|          static_int8           |      successful      |    12.149    |        0.786         |\n",
    "|         jit_fp32_ipex          |      successful      |    18.647    |        0.794*        |\n",
    "|  jit_fp32_ipex_channels_last   |      successful      |    21.505    |        0.794*        |\n",
    "|         jit_bf16_ipex          |      successful      |     9.7      |        0.792         |\n",
    "|  jit_bf16_ipex_channels_last   |      successful      |     9.84     |        0.792         |\n",
    "|         openvino_fp32          |      successful      |    24.205    |        0.794*        |\n",
    "|         openvino_int8          |      successful      |    5.805     |        0.792         |\n",
    "|        onnxruntime_fp32        |      successful      |    19.792    |        0.794*        |\n",
    "|    onnxruntime_int8_qlinear    |      successful      |     7.34     |         0.79         |\n",
    " -------------------------------- ---------------------- -------------- ----------------------\n",
    "* means we assume the metric value of the traced model does not change, so we don't recompute metric value to save time.\n",
    "Optimization cost 94.2s in total.\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> 📝 **Note**\n",
    "> \n",
 When specifying">
    "> When specifying the `training_data` parameter, set the batch size of the training data to the batch size you plan to use in your deployment environment, as the batch size may affect latency.\n",
    ">\n",
    "> For more information, please refer to the [API Documentation](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/Nano/pytorch.html#bigdl.nano.pytorch.InferenceOptimizer)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. All search mode\n",
    "When calling `optimize`, to make sure the runnng time is not too long, as shown in the above table, by default, we only iterate 10 acceleration methods that we think are generally good. However currently we have 22 acceleration methods in all, if you want to get the global optimal acceleration model, you can specify `search_mode=all` when calling `optimize`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer.optimize(model=model,\n",
    "                   training_data=train_dataloader,\n",
    "                   validation_data=val_dataloader,\n",
    "                   metric=accuracy,\n",
    "                   direction=\"max\",\n",
    "                   thread_num=1,\n",
    "                   search_mode='all',\n",
    "                   latency_sample_num=20)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The example output of `optimizer.optimize` is shown below.\n",
    "\n",
    "```bash\n",
    " -------------------------------- ---------------------- -------------- ----------------------\n",
    "|             method             |        status        | latency(ms)  |     metric value     |\n",
    " -------------------------------- ---------------------- -------------- ----------------------\n",
    "|            original            |      successful      |    30.457    |        0.794         |\n",
    "|       fp32_channels_last       |      successful      |    28.973    |        0.794*        |\n",
    "|           fp32_ipex            |      successful      |    22.663    |        0.794*        |\n",
    "|    fp32_ipex_channels_last     |      successful      |    22.669    |        0.794*        |\n",
    "|              bf16              |      successful      |    17.378    |        0.794         |\n",
    "|       bf16_channels_last       |      successful      |    17.207    |        0.794         |\n",
    "|           bf16_ipex            |      successful      |    12.634    |        0.792         |\n",
    "|    bf16_ipex_channels_last     |      successful      |    13.36     |        0.792         |\n",
    "|          static_int8           |      successful      |    12.317    |        0.786         |\n",
    "|        static_int8_ipex        |   fail to convert    |     None     |         None         |\n",
    "|            jit_fp32            |      successful      |    18.114    |        0.794*        |\n",
    "|     jit_fp32_channels_last     |      successful      |    18.434    |        0.794*        |\n",
    "|            jit_bf16            |      successful      |    28.988    |        0.794         |\n",
    "|     jit_bf16_channels_last     |      successful      |    28.907    |        0.794         |\n",
    "|         jit_fp32_ipex          |      successful      |    18.021    |        0.794*        |\n",
    "|  jit_fp32_ipex_channels_last   |      successful      |    18.088    |        0.794*        |\n",
    "|         jit_bf16_ipex          |      successful      |    9.838     |        0.792         |\n",
    "|  jit_bf16_ipex_channels_last   |      successful      |    10.315    |        0.792         |\n",
    "|         openvino_fp32          |      successful      |    24.521    |        0.794*        |\n",
    "|         openvino_int8          |      successful      |    5.774     |        0.794         |\n",
    "|        onnxruntime_fp32        |      successful      |    19.682    |        0.794*        |\n",
    "|    onnxruntime_int8_qlinear    |      successful      |    7.726     |         0.79         |\n",
    "|    onnxruntime_int8_integer    |   fail to convert    |     None     |         None         |\n",
    " -------------------------------- ---------------------- -------------- ----------------------\n",
    "* means we assume the metric value of the traced model does not change, so we don't recompute metric value to save time.\n",
    "Optimization cost 152.7s in total.\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. Filter acceleration methods\n",
    "In some cases, you may just want to test or compare several specific methods, there are two ways to achieve this.\n",
    "\n",
    "1. If you just want to test very little methods, you could just set `includes` parameter:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer.optimize(model=model,\n",
    "                   training_data=train_dataloader,\n",
    "                   validation_data=val_dataloader,\n",
    "                   metric=accuracy,\n",
    "                   direction=\"max\",\n",
    "                   thread_num=1,\n",
    "                   includes=[\"openvino_fp32\", \"onnxruntime_fp32\"],\n",
    "                   latency_sample_num=100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The example output of `optimizer.optimize` is shown below.\n",
    "\n",
    "```bash\n",
    " -------------------------------- ---------------------- -------------- ----------------------\n",
    "|             method             |        status        | latency(ms)  |     metric value     |\n",
    " -------------------------------- ---------------------- -------------- ----------------------\n",
    "|            original            |      successful      |    29.859    |        0.794         |\n",
    "|         openvino_fp32          |      successful      |    24.334    |        0.794*        |\n",
    "|        onnxruntime_fp32        |      successful      |    20.872    |        0.794*        |\n",
    " -------------------------------- ---------------------- -------------- ----------------------\n",
    "* means we assume the metric value of the traced model does not change, so we don't recompute metric value to save time.\n",
    "Optimization cost 22.8s in total.\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "2. If you want to test methods with specific precision / accelerator, or you want to test methods with / without `ipex`, you could specify `precision` / `accelerator` / `use_ipex` parameter:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer.optimize(model=model,\n",
    "                   training_data=train_dataloader,\n",
    "                   validation_data=val_dataloader,\n",
    "                   metric=accuracy,\n",
    "                   direction=\"max\",\n",
    "                   thread_num=1,\n",
    "                   accelerator=('openvino', 'jit', None),\n",
    "                   precision=('fp32', 'bf16'),\n",
    "                   use_ipex=False,\n",
    "                   latency_sample_num=100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The example output of `optimizer.optimize` is shown below.\n",
    "\n",
    "```bash\n",
    " -------------------------------- ---------------------- -------------- ----------------------\n",
    "|             method             |        status        | latency(ms)  |     metric value     |\n",
    " -------------------------------- ---------------------- -------------- ----------------------\n",
    "|            original            |      successful      |    30.978    |        0.794         |\n",
    "|       fp32_channels_last       |      successful      |    29.663    |        0.794*        |\n",
    "|              bf16              |      successful      |    17.12     |        0.794         |\n",
    "|       bf16_channels_last       |      successful      |    17.709    |        0.794         |\n",
    "|            jit_fp32            |      successful      |    18.411    |        0.794*        |\n",
    "|     jit_fp32_channels_last     |      successful      |    18.872    |        0.794*        |\n",
    "|            jit_bf16            |      successful      |    29.355    |        0.794         |\n",
    "|     jit_bf16_channels_last     |      successful      |    29.236    |        0.794         |\n",
    "|         openvino_fp32          |      successful      |    24.312    |        0.794*        |\n",
    " -------------------------------- ---------------------- -------------- ----------------------\n",
    "* means we assume the metric value of the traced model does not change, so we don't recompute metric value to save time.\n",
    "Optimization cost 60.8s in total.\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> 📝 **Note**\n",
    "> \n",
 You must pass">
    "> You must pass a tuple for the `accelerator` and `precision` parameters."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In some cases, if you expect that some acceleration methods will not work for your model / not work well / run for too long / cause exceptions to the program, you could avoid running these methods by specifying `excludes` paramater:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer.optimize(model=model,\n",
    "                   training_data=train_dataloader,\n",
    "                   validation_data=val_dataloader,\n",
    "                   metric=accuracy,\n",
    "                   direction=\"max\",\n",
    "                   thread_num=1,\n",
    "                   excludes=[\"onnxruntime_int8_qlinear\", \"openvino_int8\"],\n",
    "                   latency_sample_num=100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The example output of `optimizer.optimize` is shown below.\n",
    "\n",
    "```bash\n",
    " -------------------------------- ---------------------- -------------- ----------------------\n",
    "|             method             |        status        | latency(ms)  |     metric value     |\n",
    " -------------------------------- ---------------------- -------------- ----------------------\n",
    "|            original            |      successful      |    31.872    |        0.794         |\n",
    "|              bf16              |      successful      |    17.326    |        0.794         |\n",
    "|          static_int8           |      successful      |    12.39     |        0.786         |\n",
    "|         jit_fp32_ipex          |      successful      |    18.871    |        0.794*        |\n",
    "|  jit_fp32_ipex_channels_last   |      successful      |    18.453    |        0.794*        |\n",
    "|         jit_bf16_ipex          |      successful      |    9.863     |        0.792         |\n",
    "|  jit_bf16_ipex_channels_last   |      successful      |    9.871     |        0.792         |\n",
    "|         openvino_fp32          |      successful      |    24.585    |        0.794*        |\n",
    "|        onnxruntime_fp32        |      successful      |    19.452    |        0.794*        |\n",
    " -------------------------------- ---------------------- -------------- ----------------------\n",
    "* means we assume the metric value of the traced model does not change, so we don't recompute metric value to save time.\n",
    "Optimization cost 53.2s in total.\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4. Disable validation during optimization\n",
    "\n",
    "If you can't get corresponding validation dataloader for you model, or you don't care about the possible accuracy drop, you could omit `validation_data`, `metric`, `direction` paramaters to disable validation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer.optimize(model=model,\n",
    "                   training_data=train_dataloader,\n",
    "                   thread_num=1,\n",
    "                   latency_sample_num=100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The example output of `optimizer.optimize` is shown below.\n",
    "\n",
    "```bash\n",
    " -------------------------------- ---------------------- --------------\n",
    "|             method             |        status        | latency(ms)  |\n",
    " -------------------------------- ---------------------- --------------\n",
    "|            original            |      successful      |    29.387    |\n",
    "|              bf16              |      successful      |    16.657    |\n",
    "|          static_int8           |      successful      |    12.323    |\n",
    "|         jit_fp32_ipex          |      successful      |    18.645    |\n",
    "|  jit_fp32_ipex_channels_last   |      successful      |    18.478    |\n",
    "|         jit_bf16_ipex          |      successful      |    9.964     |\n",
    "|  jit_bf16_ipex_channels_last   |      successful      |    9.993     |\n",
    "|         openvino_fp32          |      successful      |    23.547    |\n",
    "|         openvino_int8          |      successful      |    5.711     |\n",
    "|        onnxruntime_fp32        |      successful      |    20.283    |\n",
    "|    onnxruntime_int8_qlinear    |      successful      |    7.141     |\n",
    " -------------------------------- ---------------------- --------------\n",
    "Optimization cost 49.9s in total.\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5. More flexible input format\n",
    "Now that, `optimize` can not only accept `Dataloader`, but also accept `Tensor` or tuple of `Tensor` as input, as we will automatic turn them into `Dataloader` internally."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> 📝 **Note**\n",
    "> \n",
 This function is mainly">
    "> This feature is mainly aimed at users who cannot obtain the corresponding dataloader, and it also helps with debugging.\n",
    ">\n",
    "> If you want to maximize the accuracy of the quantized model, pass in the original training/validation `DataLoader` whenever possible."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sample = next(iter(train_dataloader))\n",
    "\n",
    "optimizer.optimize(model=model,\n",
    "                   training_data=sample,\n",
    "                   thread_num=1,\n",
    "                   latency_sample_num=100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Obtain specific model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You could call `get_best_model` method to obtain the best model under specific restrictions or without restrictions. Here we get the model with minimal latency when accuracy drop less than 5%."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "When accuracy drop less than 5%, the model with minimal latency is:  openvino + int8\n"
     ]
    }
   ],
   "source": [
    "optimizer.optimize(model=model,\n",
    "                   training_data=train_dataloader,\n",
    "                   validation_data=val_dataloader,\n",
    "                   metric=accuracy,\n",
    "                   direction=\"max\",\n",
    "                   thread_num=1,\n",
    "                   latency_sample_num=100)\n",
    "\n",
    "acc_model, option = optimizer.get_best_model(accuracy_criterion=0.05)\n",
    "print(\"When accuracy drop less than 5%, the model with minimal latency is: \", option)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> 📝 **Note**\n",
    "> \n",
 If you want to find">
    "> If you want to find the best model with the `accuracy_criterion` parameter, make sure you have called `optimize` with validation data."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you just want to obtain a specific model although it doesn't have the minimal latency, you could call `get_model` method and specify `method_name`. Here we take `openvino_fp32` as an example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "oepnvino_model = optimizer.get_model(method_name='openvino_fp32')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then you could use the obtained model for inference. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "with InferenceOptimizer.get_context(acc_model):\n",
    "    x = next(iter(train_dataloader))[0]\n",
    "    output = acc_model(x)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> 📝 **Note**\n",
    "> \n",
 For all Nano optimized">
    "> For all models optimized by `InferenceOptimizer.optimize`, you need to wrap the inference steps with the automatic context manager `InferenceOptimizer.get_context(model=...)` provided by Nano. See [this guide](https://bigdl.readthedocs.io/en/latest/doc/Nano/Howto/Inference/PyTorch/pytorch_context_manager.html) for more detailed usage of the context manager."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To export the obtained model, you could simply call `InferenceOptimizer.save` method and pass the path to it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_dir = \"./best_model\"\n",
    "InferenceOptimizer.save(acc_model, save_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model files will be saved at `./best_model` directory. For different type of the obtained model, you only need to take the following files for further usage.\n",
    "\n",
    "- **OpenVINO**\n",
    "    \n",
    "    `ov_saved_model.bin`: Contains the weights and biases binary data of model\n",
    "    \n",
    "    `ov_saved_model.xml`: Model checkpoint for general use, describes model structure\n",
    "\n",
    "- **onnxruntime**\n",
    "\n",
    "    `onnx_saved_model.onnx`: Represents model checkpoint for general use, describes model structure\n",
    "    \n",
    "- **int8**\n",
    "\n",
    "    `best_model.pt`: Represents model optimized by Intel® Neural Compressor\n",
    "\n",
    "- **ipex | channel_last | jit | bf16**\n",
    "    \n",
    "    `ckpt.pt`: If `jit` in option, it stores model optimized using just-in-time compilation, otherwise, it stores original model weight by `torch.save(model.state_dict())`.\n",
    "\n",
    "- **Others**\n",
    "    \n",
    "    `saved_weight.pt`: Saved by `torch.save(model.state_dict())`."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> 📚 **Related Readings**\n",
    "> \n",
    "> - [How to install BigDL-Nano](https://bigdl.readthedocs.io/en/latest/doc/Nano/Overview/install.html)\n",
    "> - [How to enable automatic context management for PyTorch inference on Nano optimized models](https://bigdl.readthedocs.io/en/latest/doc/Nano/Howto/Inference/PyTorch/pytorch_context_manager.html)\n",
    "> - [How to save and load optimized ONNXRuntime model](https://bigdl.readthedocs.io/en/latest/doc/Nano/Howto/Inference/PyTorch/pytorch_save_and_load_onnx.html)\n",
    "> - [How to save and load optimized OpenVINO model](https://bigdl.readthedocs.io/en/latest/doc/Nano/Howto/Inference/PyTorch/pytorch_save_and_load_openvino.html)\n",
    "> - [How to save and load optimized JIT model](https://bigdl.readthedocs.io/en/latest/doc/Nano/Howto/Inference/PyTorch/pytorch_save_and_load_jit.html)\n",
    "> - [How to save and load optimized IPEX model](https://bigdl.readthedocs.io/en/latest/doc/Nano/Howto/Inference/PyTorch/pytorch_save_and_load_ipex.html)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "nano-pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.15 (default, Nov 24 2022, 21:12:53) \n[GCC 11.2.0]"
  },
  "vscode": {
   "interpreter": {
    "hash": "8a5edab282632443219e051e4ade2d1d5bbc671c781051bf1437897cbdfea0f1"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
